package com.mch.ar.core;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import javax.sql.DataSource;

import com.mch.ar.dialect.Dialect;
import com.mch.ar.exce.DBOpenException;
import com.mch.ar.exce.IllegalTableNameException;
import com.mch.ar.exce.SqlExecuteException;
import com.mch.ar.exce.TransactionException;
import com.mch.ar.exce.UnsupportedDatabaseException;
import com.mch.ar.pool.SingletonDataSource;
import com.mch.utils.Seq;

/**
 * Database handle.
 * <p>
 * Wraps a {@link DataSource} together with the {@link Dialect} detected for it,
 * caches per-table column metadata, and binds one connection to the current
 * thread while {@link #batch(Runnable)} runs so that every statement issued
 * inside the transaction shares it.
 *
 * @since 1.0
 * @author redraiment
 * @modifier MCHWEB.NET
 */
public final class DB {
	/**
	 * Dialects discovered via {@link ServiceLoader}. ServiceLoader is not safe for
	 * concurrent use, so it is only iterated while holding its own monitor.
	 */
	private static final ServiceLoader<Dialect> dialects;

	/** When {@code true}, {@link #debug(Object)} echoes SQL and diagnostics to {@code System.out}. */
	public static boolean isDev = true;

	static {
		dialects = ServiceLoader.load(Dialect.class);
	}

	/**
	 * Opens a database by JDBC URL without any connection properties.
	 *
	 * @param url JDBC URL
	 * @return the opened database
	 */
	public static DB open(String url) {
		return open(url, new Properties());
	}

	/**
	 * Opens a database by JDBC URL with the given credentials.
	 * <p>
	 * A {@code null} username or password is simply not passed to the driver,
	 * because {@link Properties} rejects {@code null} values.
	 *
	 * @param url      JDBC URL
	 * @param username value for the {@code user} property, may be {@code null}
	 * @param password value for the {@code password} property, may be {@code null}
	 * @return the opened database
	 */
	public static DB open(String url, String username, String password) {
		Properties info = new Properties();
		if (username != null) {
			info.put("user", username);
		}
		if (password != null) {
			info.put("password", password);
		}
		return open(url, info);
	}

	/**
	 * Opens a database by JDBC URL with arbitrary driver properties, backed by a
	 * {@link SingletonDataSource}.
	 *
	 * @param url  JDBC URL
	 * @param info driver properties
	 * @return the opened database
	 * @throws DBOpenException if the data source cannot be created
	 */
	public static DB open(String url, Properties info) {
		try {
			return open(new SingletonDataSource(url, info));
		} catch (SQLException e) {
			throw new DBOpenException(e);
		}
	}

	/**
	 * Opens a database over an existing data source, using the first registered
	 * {@link Dialect} whose {@code accept} returns {@code true} for a probe connection.
	 *
	 * @param pool the data source to use for all later connections
	 * @return the opened database
	 * @throws UnsupportedDatabaseException if no registered dialect accepts the database
	 * @throws DBOpenException if the probe connection or its metadata cannot be obtained
	 */
	public static DB open(DataSource pool) {
		// try-with-resources closes the probe connection on every exit path.
		try (Connection probe = pool.getConnection()) {
			synchronized (dialects) {
				for (Dialect dialect : dialects) {
					if (dialect.accept(probe)) {
						return new DB(pool, dialect);
					}
				}
			}

			DatabaseMetaData meta = probe.getMetaData();
			String version = String.format("%s %d.%d/%s", meta.getDatabaseProductName(), meta.getDatabaseMajorVersion(),
					meta.getDatabaseMinorVersion(), meta.getDatabaseProductVersion());
			throw new UnsupportedDatabaseException(version);
		} catch (SQLException e) {
			throw new DBOpenException(e);
		}
	}

	private final DataSource pool;
	/** Connection bound to the current thread by {@link #batch(Runnable)}; {@code null} outside a transaction. */
	private final InheritableThreadLocal<Connection> base;
	private final Dialect dialect;
	/** Column metadata per (case-normalized) table name; see {@link #getColumns(String)}. */
	private final ConcurrentMap<String, Map<String, Integer>> columns;
	/** Per-table association maps handed to every {@link Table} created for that name. */
	private final ConcurrentMap<String, Map<String, Association>> relations;
	/** Per-table hook maps handed to every {@link Table} created for that name. */
	private final ConcurrentMap<String, Map<String, Lambda>> hooks;

	private DB(DataSource pool, Dialect dialect) {
		this.pool = pool;
		this.base = new InheritableThreadLocal<>();
		// Concurrent maps: these caches are read outside any lock in getColumns()/active().
		this.columns = new ConcurrentHashMap<>();
		this.relations = new ConcurrentHashMap<>();
		this.dialect = dialect;
		this.hooks = new ConcurrentHashMap<>();
	}

	/**
	 * Returns the connection bound to the current transaction, or a fresh one
	 * from the pool when no transaction is active.
	 */
	private Connection getConnection() {
		Connection bound = base.get();
		if (bound != null) {
			return bound;
		}
		try {
			return pool.getConnection();
		} catch (SQLException e) {
			throw new DBOpenException(e);
		}
	}

	/**
	 * Closes {@code c} unless it is {@code null} or is the connection bound to the
	 * current transaction (that one is closed by {@link #batch(Runnable)}).
	 */
	void close(Connection c) {
		if (c != null && base.get() != c) {
			try {
				c.close();
			} catch (SQLException e) {
				throw new RuntimeException("close Connection fail", e);
			}
		}
	}

	/**
	 * Closes {@code s} and then its connection (subject to {@link #close(Connection)}).
	 * The connection is released even if closing the statement fails.
	 */
	void close(Statement s) {
		if (s == null) {
			return;
		}
		Connection c = null;
		try {
			c = s.getConnection();
			s.close();
		} catch (SQLException e) {
			throw new RuntimeException("close Statement fail", e);
		} finally {
			close(c);
		}
	}

	/**
	 * Closes {@code rs}, then its statement and connection. The statement (and
	 * thus the connection) is released even if closing the result set fails.
	 */
	void close(ResultSet rs) {
		if (rs == null) {
			return;
		}
		Statement s = null;
		try {
			s = rs.getStatement();
			rs.close();
		} catch (SQLException e) {
			throw new RuntimeException("close ResultSet fail", e);
		} finally {
			close(s);
		}
	}

	/**
	 * Lists the tables of type {@code TABLE} visible through the connection metadata.
	 *
	 * @return map from table name to the name of a primary-key column, or to
	 *         {@code null} when the table has no primary key. For a composite key
	 *         only the first column reported by the driver is recorded.
	 * @throws DBOpenException on metadata access failure
	 */
	public Map<String, String> getTableNames() {
		Map<String, String> tables = new HashMap<>();
		try (Connection c = pool.getConnection()) {
			DatabaseMetaData meta = c.getMetaData();

			// Collect names first so at most one metadata ResultSet is open at a time
			// (some drivers cannot interleave result sets on one connection).
			List<String> names = new ArrayList<>();
			try (ResultSet rs = meta.getTables(null, null, "%", new String[] { "TABLE" })) {
				while (rs.next()) {
					names.add(rs.getString("table_name"));
				}
			}

			for (String tableName : names) {
				try (ResultSet pk = meta.getPrimaryKeys(null, null, tableName)) {
					tables.put(tableName, pk.next() ? pk.getString("COLUMN_NAME") : null);
				}
			}
		} catch (SQLException e) {
			throw new DBOpenException(e);
		}
		return tables;
	}

	/**
	 * Returns a {@link Table} for every table reported by {@link #getTableNames()}.
	 * Tables without a primary key fall back to {@link #active(String)}'s default
	 * key column {@code "id"}.
	 */
	public Set<Table> getTables() {
		Set<Table> tables = new HashSet<>();
		for (Map.Entry<String, String> entry : getTableNames().entrySet()) {
			String pk = entry.getValue();
			tables.add(pk == null ? active(entry.getKey()) : active(entry.getKey(), pk));
		}
		return tables;
	}

	/**
	 * Returns (and caches) the columns of {@code name}, keyed by
	 * {@link #parseKeyParameter(String)} of the column name, with the
	 * {@code java.sql.Types} code as value. The {@code id}, {@code created_at}
	 * and {@code updated_at} columns are excluded.
	 *
	 * @param name {@code table}, {@code schema.table} or {@code catalog.schema.table}
	 * @throws IllegalArgumentException if {@code name} has more than three parts
	 */
	private Map<String, Integer> getColumns(String name) throws SQLException {
		Map<String, Integer> cached = columns.get(name);
		if (cached != null) {
			return cached;
		}
		// Serialize loading so the metadata for one name is fetched only once.
		synchronized (columns) {
			cached = columns.get(name);
			if (cached == null) {
				cached = loadColumns(name);
				columns.put(name, cached);
			}
			return cached;
		}
	}

	/** Reads column metadata for {@code name} from the database (no caching). */
	private Map<String, Integer> loadColumns(String name) throws SQLException {
		String[] parts = name.split("\\.");
		String catalog = null;
		String schema = null;
		String table;
		switch (parts.length) {
		case 1:
			table = parts[0];
			break;
		case 2:
			schema = parts[0];
			table = parts[1];
			break;
		case 3:
			catalog = parts[0];
			schema = parts[1];
			table = parts[2];
			break;
		default:
			throw new IllegalArgumentException(String.format("Illegal table name: %s", name));
		}

		Map<String, Integer> column = new LinkedHashMap<>();
		try (Connection c = pool.getConnection()) {
			DatabaseMetaData meta = c.getMetaData();
			try (ResultSet rs = meta.getColumns(catalog, schema, table, null)) {
				while (rs.next()) {
					String columnName = rs.getString("column_name");
					if (columnName.equalsIgnoreCase("id") || columnName.equalsIgnoreCase("created_at")
							|| columnName.equalsIgnoreCase("updated_at")) {
						continue;
					}
					column.put(parseKeyParameter(columnName), rs.getInt("data_type"));
				}
			}
		}
		return column;
	}

	/** Drops cached column metadata for {@code name} so the next lookup re-reads it. */
	private void evictColumns(String name) {
		columns.remove(dialect.getCaseIdentifier(name));
	}

	/**
	 * Returns a {@link Table} for {@code name} whose primary key column is {@code "id"}.
	 */
	public Table active(String name) {
		return active(name, "id");
	}

	/**
	 * Returns a {@link Table} for {@code name}. The name is normalized with the
	 * dialect's {@code getCaseIdentifier}; association and hook maps are shared by
	 * every Table created for the same normalized name.
	 *
	 * @param name        table name (optionally schema- or catalog-qualified)
	 * @param pkFieldName primary key column name
	 * @throws IllegalTableNameException if the column metadata cannot be read
	 */
	public Table active(String name, String pkFieldName) {
		name = dialect.getCaseIdentifier(name);

		if (!relations.containsKey(name)) {
			relations.putIfAbsent(name, new HashMap<String, Association>());
		}
		if (!hooks.containsKey(name)) {
			hooks.putIfAbsent(name, new HashMap<String, Lambda>());
		}

		try {
			return new Table(this, name, getColumns(name), relations.get(name), hooks.get(name), pkFieldName);
		} catch (SQLException e) {
			throw new IllegalTableNameException(name, e);
		}
	}

	/**
	 * Prepares {@code sql} and binds {@code params} positionally.
	 * <p>
	 * The connection comes from the current transaction if one is active, otherwise
	 * from the pool. Statements starting with {@code insert} (case-insensitive) are
	 * prepared with {@link Statement#RETURN_GENERATED_KEYS}. A {@code null} parameter
	 * is bound with {@code setNull} using {@code types[i]}, or {@link Types#NULL}
	 * when {@code types} is {@code null} or too short. On failure the statement and
	 * connection acquired here are released.
	 *
	 * @throws SqlExecuteException if preparing or binding fails
	 */
	public PreparedStatement prepare(String sql, Object[] params, int[] types) {
		Connection c = getConnection();
		PreparedStatement call = null;
		try {
			// regionMatches(ignoreCase) is locale-independent, unlike toLowerCase().
			if (sql.trim().regionMatches(true, 0, "insert", 0, 6)) {
				call = c.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
			} else {
				call = c.prepareStatement(sql);
			}

			if (params != null) {
				for (int i = 0; i < params.length; i++) {
					if (params[i] != null) {
						call.setObject(i + 1, params[i]);
					} else {
						int sqlType = (types != null && i < types.length) ? types[i] : Types.NULL;
						call.setNull(i + 1, sqlType);
					}
				}
			}
			return call;
		} catch (SQLException e) {
			// close(Statement) also releases the connection; otherwise release it directly.
			if (call != null) {
				close(call);
			} else {
				close(c);
			}
			throw new SqlExecuteException(sql, e);
		}
	}

	/**
	 * Executes an update or DDL statement. The statement, and its connection when
	 * no transaction is active, is closed afterwards.
	 *
	 * @return the update count reported by the driver
	 * @throws SqlExecuteException if execution fails
	 */
	public int execute(String sql, Object[] params, int[] types) {
		PreparedStatement call = prepare(sql, params, types);
		debug(sql);

		try {
			return call.executeUpdate();
		} catch (SQLException e) {
			throw new SqlExecuteException(sql, e);
		} finally {
			close(call);
		}
	}

	/** Executes {@code sql} without parameters; see {@link #execute(String, Object[], int[])}. */
	public int execute(String sql) {
		return execute(sql, null, null);
	}

	/**
	 * Runs a query and copies every row into a {@link Record} keyed by column name.
	 * All JDBC resources are released before returning.
	 *
	 * @return the rows, or {@code null} when the query returns no rows
	 * @throws SqlExecuteException if the query fails
	 */
	public List<Record> queryList(String sql, Object... params) {
		ResultSet rs = query(sql, params);
		try {
			ResultSetMetaData rsmd = rs.getMetaData();
			int columnCount = rsmd.getColumnCount();

			List<Record> list = new ArrayList<>();
			while (rs.next()) {
				Record re = new Record();
				for (int i = 1; i <= columnCount; i++) {
					re.set(rsmd.getColumnName(i), rs.getObject(i));
				}
				list.add(re);
			}
			// null-for-empty is kept for callers that test the result against null.
			return list.isEmpty() ? null : list;
		} catch (SQLException e) {
			throw new SqlExecuteException(sql, e);
		} finally {
			close(rs);
		}
	}

	/**
	 * Runs a query and returns its open {@link ResultSet}. The caller owns it and
	 * should release it with {@link #close(ResultSet)}, which also closes the
	 * statement and (outside a transaction) the connection.
	 *
	 * @throws SqlExecuteException if the query fails (resources are released first)
	 */
	public ResultSet query(String sql, Object... params) {
		PreparedStatement call = prepare(sql, params, null);
		debug(sql);

		try {
			return call.executeQuery();
		} catch (SQLException e) {
			close(call);
			throw new SqlExecuteException(sql, e);
		}
	}

	/**
	 * Creates a table with an {@code id} identity column, the given column
	 * definitions, and {@code created_at}/{@code updated_at} timestamps.
	 * Identifiers are inserted into the DDL verbatim and must be trusted.
	 */
	public Table createTable(String name, String... columns) {
		return createTable(name, "id", columns);
	}

	/**
	 * Creates a table whose identity column is {@code pkFieldName}, followed by the
	 * given column definitions and {@code created_at}/{@code updated_at} timestamps.
	 * Identifiers are inserted into the DDL verbatim and must be trusted.
	 *
	 * @return a {@link Table} for the new table, built from freshly read metadata
	 */
	public Table createTable(String name, String pkFieldName, String... columns) {
		// Empty/null column lists must not leave a dangling ", " in the DDL.
		String extra = (columns == null || columns.length == 0) ? ""
				: Seq.join(Arrays.asList(columns), ", ") + ", ";
		String template = "create table %s (%s %s, %screated_at timestamp, updated_at timestamp)";
		execute(String.format(template, name, pkFieldName, dialect.getIdentity(), extra));
		// Metadata cached before the table existed would otherwise be reused.
		evictColumns(name);
		return active(name, pkFieldName);
	}

	/** Drops {@code name} if it exists and forgets its cached column metadata. */
	public void dropTable(String name) {
		execute(String.format("drop table if exists %s", name));
		evictColumns(name);
	}

	/** Creates index {@code name} on {@code table} over {@code columns}. */
	public void createIndex(String name, String table, String... columns) {
		execute(String.format("create index %s on %s(%s)", name, table, Seq.join(Arrays.asList(columns), ", ")));
	}

	/**
	 * Drops index {@code name}. The {@code table} argument is currently not used
	 * in the generated SQL.
	 */
	public void dropIndex(String name, String table) {
		execute(String.format("drop index %s", name));
	}

	/* Transaction */

	/**
	 * Runs {@code transaction} on one connection with auto-commit disabled, binding
	 * that connection to the current thread so statements issued through this DB
	 * inside the callback share it. Commits on success; rolls back when the callback
	 * or the commit throws. The original auto-commit setting is restored afterwards.
	 * <p>
	 * TODO: nested transactions are not supported. A nested call takes its own
	 * connection and clears the thread binding when it exits.
	 *
	 * @throws TransactionException if disabling auto-commit, committing, or rolling back fails
	 * @throws DBOpenException if the connection cannot be obtained, queried or restored
	 */
	public void batch(Runnable transaction) {
		try (Connection c = pool.getConnection()) {
			boolean autoCommit = c.getAutoCommit();
			try {
				c.setAutoCommit(false);
			} catch (SQLException e) {
				throw new TransactionException("transaction setAutoCommit(false)", e);
			}
			base.set(c);

			try {
				transaction.run();
				try {
					c.commit();
				} catch (SQLException e) {
					throw new TransactionException("transaction commit", e);
				}
			} catch (RuntimeException e) {
				try {
					c.rollback();
					c.setAutoCommit(autoCommit);
				} catch (SQLException ex) {
					throw new TransactionException("transaction rollback: " + ex.getMessage(), e);
				}
				throw e;
			} catch (Error e) {
				// Best effort: never let a rollback failure hide the original Error.
				try {
					c.rollback();
					c.setAutoCommit(autoCommit);
				} catch (SQLException ex) {
					e.addSuppressed(ex);
				}
				throw e;
			}
			c.setAutoCommit(autoCommit);
		} catch (SQLException e) {
			throw new DBOpenException(e);
		} finally {
			base.set(null);
		}
	}

	/**
	 * Runs {@link #batch(Runnable)} and reports success instead of throwing.
	 * The failure is passed to {@link #debug(Object)}.
	 *
	 * @return {@code true} if the transaction committed, {@code false} otherwise
	 */
	public boolean tx(Runnable transaction) {
		try {
			batch(transaction);
			return true;
		} catch (Throwable e) {
			debug(e);
			return false;
		}
	}

	/** Prints {@code o} to {@code System.out} when {@link #isDev} is {@code true}. */
	public void debug(Object o) {
		if (isDev) {
			System.out.println(o);
		}
	}

	/* Utility */

	/** Returns the current time as a JDBC {@link Timestamp}. */
	public static Timestamp now() {
		return new Timestamp(System.currentTimeMillis());
	}

	/**
	 * Normalizes a column or parameter key: lower-cases it (locale-independently)
	 * and strips one trailing {@code ':'}.
	 */
	static String parseKeyParameter(String name) {
		name = name.toLowerCase(Locale.ROOT);
		if (name.endsWith(":")) {
			name = name.substring(0, name.length() - 1);
		}
		return name;
	}
}
